from collections import Counter
import pybedtools
from matplotlib import patches
from matplotlib.colors import LinearSegmentedColormap

from pylab import *
from numpy import *
from scipy.stats import spearmanr, pearsonr

from rpy2 import robjects
import rpy2.robjects.numpy2ri
robjects.numpy2ri.activate()
from rpy2.robjects.packages import importr
deseq = importr('DESeq2')
print("Using DESeq2 version %s" % deseq.__version__)
glmGamPoi = importr('glmGamPoi')

import warnings
warnings.filterwarnings('error')


# Sampling timepoints in hours (column names use e.g. "t00" / "00_hr").
# The position of a timepoint in this tuple is the index used for
# cage_indices[i] / hiseq_indices[i] throughout this script.
timepoints = (0, 1, 4, 12, 24, 96)

def read_ppvalues():
    """Read per-enhancer signed significance scores from the DESeq2 table.

    Reads ``enhancers.deseq.txt`` from the current directory. The header must
    start with ``enhancer`` followed by a ``basemean``/``log2fc``/``pvalue``
    column triplet for each timepoint (00hr, 01hr, 04hr, 12hr, 24hr, 96hr)
    and finally an ``all`` triplet; only the ``all`` columns are used.

    Returns:
        dict mapping enhancer name to ``-log10(all_pvalue) * sign(all_log2fc)``,
        in file order. Callers rely on this order matching the row order of
        the expression table.
    """
    filename = 'enhancers.deseq.txt'
    print("Reading %s" % filename)
    expected_header = ["enhancer"]
    for prefix in ("00hr", "01hr", "04hr", "12hr", "24hr", "96hr", "all"):
        for suffix in ("basemean", "log2fc", "pvalue"):
            expected_header.append("%s_%s" % (prefix, suffix))
    ppvalues = {}
    # ``with`` guarantees the file is closed even if a format assertion fails.
    with open(filename) as handle:
        words = next(handle).split()
        for index, expected in enumerate(expected_header):
            assert words[index] == expected
        for line in handle:
            words = line.split()
            assert len(words) == 22
            name = words[0]
            log2fc = float(words[20])  # all_log2fc
            pvalue = float(words[21])  # all_pvalue
            # Signed score: magnitude is -log10(p), sign follows the fold change.
            ppvalues[name] = -log10(pvalue) * sign(log2fc)
    return ppvalues

def read_expression_data(ppvalues):
    """Read per-library enhancer counts and group the library columns by timepoint.

    Reads ``enhancers.expression.txt`` from the current directory. The header
    lists ``enhancer`` followed by 17 HiSeq and 16 CAGE library columns. Each
    data row holds an enhancer name of the form ``dataset|locus`` followed by
    33 cells of two comma-separated counts, which are summed.

    Args:
        ppvalues: dict returned by ``read_ppvalues``; its key order must match
            the row order of the expression file exactly.

    Returns:
        ``(cage_indices, hiseq_indices, data)`` where ``cage_indices[i]`` and
        ``hiseq_indices[i]`` are integer arrays of column indices into ``data``
        for ``timepoints[i]``, and ``data`` is an integer array with one row
        per enhancer and one column per library.
    """
    names = list(ppvalues.keys())
    filename = "enhancers.expression.txt"
    print("Reading", filename)
    expected_header = (
        'enhancer',
        'HiSeq_t00_r1', 'HiSeq_t00_r2', 'HiSeq_t00_r3',
        'HiSeq_t01_r1', 'HiSeq_t01_r2',
        'HiSeq_t04_r1', 'HiSeq_t04_r2', 'HiSeq_t04_r3',
        'HiSeq_t12_r1', 'HiSeq_t12_r2', 'HiSeq_t12_r3',
        'HiSeq_t24_r1', 'HiSeq_t24_r2', 'HiSeq_t24_r3',
        'HiSeq_t96_r1', 'HiSeq_t96_r2', 'HiSeq_t96_r3',
        'CAGE_00_hr_A', 'CAGE_00_hr_C', 'CAGE_00_hr_G', 'CAGE_00_hr_H',
        'CAGE_01_hr_A', 'CAGE_01_hr_C', 'CAGE_01_hr_G',
        'CAGE_04_hr_C', 'CAGE_04_hr_E',
        'CAGE_12_hr_A', 'CAGE_12_hr_C',
        'CAGE_24_hr_C', 'CAGE_24_hr_E',
        'CAGE_96_hr_A', 'CAGE_96_hr_C', 'CAGE_96_hr_E',
    )
    hiseq_indices = [[] for _ in timepoints]
    cage_indices = [[] for _ in timepoints]
    data = []
    with open(filename) as handle:
        words = next(handle).split()
        for index, expected in enumerate(expected_header):
            assert words[index] == expected
        # Column index is relative to the count columns (the name column is
        # skipped), so it indexes directly into rows of ``data``.
        for index, word in enumerate(words[1:]):
            condition, replicate = word.rsplit("_", 1)
            experiment, timepoint = condition.split("_", 1)
            if experiment == "HiSeq":
                assert replicate in ("r1", "r2", "r3")
                assert timepoint.startswith("t")
                timepoint = int(timepoint[1:])
                hiseq_indices[timepoints.index(timepoint)].append(index)
            elif experiment == "CAGE":
                # Exact membership: a substring test against "ACEGH" would
                # also accept "" or multi-letter labels such as "CE".
                assert replicate in ("A", "C", "E", "G", "H")
                assert timepoint.endswith("_hr")
                timepoint = int(timepoint[:-3])
                cage_indices[timepoints.index(timepoint)].append(index)
            else:
                raise Exception("Unknown experiment %s" % experiment)
        hiseq_indices = [array(indices) for indices in hiseq_indices]
        cage_indices = [array(indices) for indices in cage_indices]
        n = sum(len(h) + len(c) for h, c in zip(hiseq_indices, cage_indices))
        assert n == 33
        for i, line in enumerate(handle):
            words = line.split()
            assert len(words) == 34
            name = words[0]
            dataset, locus = name.split("|")
            assert dataset in ("FANTOM5", "CAGE", "HiSeq", "Both")
            assert name == names[i]
            row = []
            for word in words[1:34]:
                count1, count2 = word.split(",")
                row.append(int(count1) + int(count2))
            data.append(row)
    # Every enhancer from the DESeq2 table must have an expression row;
    # downstream code pairs the two by position.
    assert len(data) == len(names)
    data = array(data)
    return cage_indices, hiseq_indices, data

def estimate_normalization_factors(counts, cage_indices, hiseq_indices):
    """Estimate DESeq2 size factors for every library column of ``counts``.

    Each column is labelled with its dataset ("CAGE" or "HiSeq") and its
    timepoint (as a string) using the per-timepoint index arrays. Those labels
    form the colData of a DESeqDataSet with design ``~ timepoint + dataset``,
    on which R's ``estimateSizeFactors`` is run.

    Returns:
        The value of R's ``sizeFactors`` for the resulting dataset (DESeq2
        reports one factor per sample, i.e. per column of ``counts``).
    """
    _, ncolumns = shape(counts)
    sample_dataset = [None] * ncolumns
    sample_timepoint = [None] * ncolumns
    for position, timepoint in enumerate(timepoints):
        label = str(timepoint)
        for column in cage_indices[position]:
            sample_dataset[column] = "CAGE"
            sample_timepoint[column] = label
        for column in hiseq_indices[position]:
            sample_dataset[column] = "HiSeq"
            sample_timepoint[column] = label
    coldata = robjects.DataFrame({
        'dataset': robjects.StrVector(sample_dataset),
        'timepoint': robjects.StrVector(sample_timepoint),
    })
    dds = deseq.DESeqDataSetFromMatrix(
        countData=counts,
        colData=coldata,
        design=robjects.Formula("~ timepoint + dataset"),
    )
    dds = robjects.r['estimateSizeFactors'](dds)
    return robjects.r['sizeFactors'](dds)

def _average_tpm(tpm, groups):
    """Return one value per row of ``tpm``: the mean over ``groups`` of per-group column means.

    ``groups`` is a sequence of column-index arrays (one per timepoint here).
    Averaging within each group first gives every group equal weight,
    regardless of how many replicate libraries it contains.
    """
    per_group = zeros((len(tpm), len(groups)))
    for i, indices in enumerate(groups):
        per_group[:, i] = mean(tpm[:, indices], 1)
    return mean(per_group, 1)


def make_figure_scatter(cage_indices, hiseq_indices, counts, factors, ppvalues):
    """Scatter average CAGE vs. HiSeq expression per enhancer and save the figure.

    Counts are divided by the per-library size factors, then rescaled so the
    mean library total is 1e6 (tpm). For each enhancer the tpm is averaged
    within each timepoint and then across timepoints, separately for the CAGE
    and HiSeq libraries. Points are coloured by the signed score in
    ``ppvalues``, whose value order must match the rows of ``counts``.
    Spearman and Pearson (on log(tpm + 1)) correlations are printed, and the
    figure is written to ``figure_enhancer_scatterplot.svg`` and ``.png``.
    """
    cmap = LinearSegmentedColormap.from_list("mycmap", ['blue', 'lightgray', 'red'])
    colors = array(list(ppvalues.values()))
    tpm = counts / factors
    # Size-factor-scaled library totals are not 1e6 in general, so rescale so
    # that the average library sums to exactly 1 million.
    total = mean(sum(tpm, 0))
    tpm *= 1.e6 / total
    assert len(cage_indices) == len(hiseq_indices)
    n = len(counts)
    cage_tpm = _average_tpm(tpm, cage_indices)
    hiseq_tpm = _average_tpm(tpm, hiseq_indices)
    assert len(cage_tpm) == n
    assert len(hiseq_tpm) == n
    print("Plotting a scatter plot with %d points" % n)
    figure(figsize=(5, 5))
    ax = axes([0.20, 0.20, 0.55, 0.55])
    scatter(cage_tpm, hiseq_tpm, c=colors, s=3, vmin=-12, vmax=+12, cmap=cmap)
    xscale('log')
    yscale('log')
    ax.set_aspect('equal', adjustable='box')
    # Use a common range on both axes so the diagonal is meaningful.
    xmin, xmax = xlim()
    ymin, ymax = ylim()
    minimum = min(xmin, ymin)
    maximum = max(xmax, ymax)
    xlim(minimum, maximum)
    ylim(minimum, maximum)
    xticks(fontsize=8)
    yticks(fontsize=8)
    xlabel("Long capped RNA (CAGE libraries),\naverage expression [tpm]", fontsize=8)
    ylabel("Short capped RNA (single-end libraries),\naverage expression [tpm]", fontsize=8)
    cax = axes([0.80, 0.20, 0.03, 0.55])
    cb = colorbar(cax=cax)
    cb.set_label("$-\\log(p) \\times \\mathrm{sign}$", fontsize=8)
    cb.ax.tick_params(labelsize=8)
    r, p = spearmanr(cage_tpm, hiseq_tpm)
    print("CAGE average expression vs HiSeq average expression: Spearman correlation across enhancers is %.2f (p = %g)" % (r, p))
    r, p = pearsonr(log(cage_tpm+1), log(hiseq_tpm+1))
    print("CAGE average expression vs HiSeq average expression: Pearson correlation across enhancers is %.2f (p = %g)" % (r, p))
    filename = "figure_enhancer_scatterplot.svg"
    print("Saving figure as %s" % filename)
    savefig(filename)
    filename = "figure_enhancer_scatterplot.png"
    print("Saving figure as %s" % filename)
    savefig(filename)


ppvalues = read_ppvalues()

cage_indices, hiseq_indices, data = read_expression_data(ppvalues)

# Classify every enhancer by four flags: any HiSeq counts, any CAGE counts,
# signed score above +threshold (labelled hiseq_significant), and signed
# score below -threshold (labelled cage_significant).
ppvalue_threshold = -log10(0.05)
counts = Counter()
for name, row in zip(ppvalues, data):
    key = []
    if sum([sum(row[ii]) for ii in hiseq_indices]) > 1e-20:
        key.append("hiseq_expressed")
    else:
        key.append("hiseq_not_expressed")
    if sum([sum(row[ii]) for ii in cage_indices]) > 1e-20:
        key.append("cage_expressed")
    else:
        key.append("cage_not_expressed")
    if ppvalues[name] > +ppvalue_threshold:
        key.append("hiseq_significant")
    else:
        key.append("hiseq_not_significant")
    if ppvalues[name] < -ppvalue_threshold:
        key.append("cage_significant")
    else:
        key.append("cage_not_significant")
    counts[tuple(key)] += 1

for key, count in counts.items():
    print(key, count)

# Venn diagram: large circles = expressed, small inner circles = significant.
circle_hiseq_expressed = Circle((-1., 0), 3, color='r', alpha=0.2)
circle_cage_expressed = Circle((+1., 0), 3, color='b', alpha=0.2)
circle_hiseq_significant = Circle((-2, 0), 1.75, color='r', alpha=0.2)
circle_cage_significant = Circle((+2, 0), 1.75, color='b', alpha=0.2)

figure()
ax = subplot(111)

ax.add_patch(circle_hiseq_expressed)
ax.add_patch(circle_cage_expressed)
ax.add_patch(circle_hiseq_significant)
ax.add_patch(circle_cage_significant)

axis('square')
xlim(-5,5)
ylim(-5,5)


# Each count is placed inside the region of the diagram it belongs to.
number1 = counts[('hiseq_expressed', 'cage_expressed', 'hiseq_not_significant', 'cage_not_significant')]
text(0, -2, str(number1), horizontalalignment='center', verticalalignment='center')
number2 = counts[('hiseq_not_expressed', 'cage_expressed', 'hiseq_not_significant', 'cage_not_significant')]
text(+1.95, -2.25, str(number2), horizontalalignment='center', verticalalignment='center')
number3 = counts[('hiseq_expressed', 'cage_not_expressed', 'hiseq_not_significant', 'cage_not_significant')]
text(-1.95, -2.25, str(number3), horizontalalignment='center', verticalalignment='center')
number4 = counts[('hiseq_expressed', 'cage_expressed', 'hiseq_significant', 'cage_not_significant')]
text(-1.2, 0, str(number4), horizontalalignment='center', verticalalignment='center')
number5 = counts[('hiseq_expressed', 'cage_expressed', 'hiseq_not_significant', 'cage_significant')]
text(+1.2, 0, str(number5), horizontalalignment='center', verticalalignment='center')
number6 = counts[('hiseq_expressed', 'cage_not_expressed', 'hiseq_significant', 'cage_not_significant')]
text(-2.9, 0, str(number6), horizontalalignment='center', verticalalignment='center')
number7 = counts[('hiseq_not_expressed', 'cage_expressed', 'hiseq_not_significant', 'cage_significant')]
text(+2.9, 0, str(number7), horizontalalignment='center', verticalalignment='center')
text(-2, 3, 'HiSeq', color='red', horizontalalignment='center', verticalalignment='bottom')
text(+2, 3, 'CAGE', color='blue', horizontalalignment='center', verticalalignment='bottom')

hiseq_expressed_patch = patches.Patch(color='red', alpha=0.2, label='Expressed as\nshort capped RNAs')
cage_expressed_patch = patches.Patch(color='blue', alpha=0.2, label='Expressed as\nlong capped RNAs')
hiseq_significant_patch = patches.Patch(color='red', alpha=0.4, label='Significantly enriched\nin short capped RNA\n(single-end) libraries')
cage_significant_patch = patches.Patch(color='blue', alpha=0.4, label='Significantly enriched\nin long capped RNA\n(CAGE) libraries')

legend(handles=[hiseq_expressed_patch, hiseq_significant_patch, cage_expressed_patch, cage_significant_patch], fontsize=8, loc='lower center', ncol=2)

title("Enhancers")

axis('off')

filename = "figure_enhancer_venn_diagram.svg"
print("Saving Venn diagram to", filename)
savefig(filename)

filename = "figure_enhancer_venn_diagram.png"
print("Saving Venn diagram to", filename)
savefig(filename)

factors = estimate_normalization_factors(data, cage_indices, hiseq_indices)

# make_figure_scatter writes figure_enhancer_scatterplot.svg/.png itself, so
# the figure is not saved a second time here.
make_figure_scatter(cage_indices, hiseq_indices, data, factors, ppvalues)

print("Total number of enhancers: %d" % len(ppvalues))
print("Expressed in HiSeq and CAGE: %d" % (number1 + number4 + number5))
print("Expressed in HiSeq only: %d" % (number3 + number6))
print("Expressed in CAGE only: %d" % (number2 + number7))
print("Significantly higher expression in HiSeq: %d" % (number4 + number6))
print("Significantly higher expression in CAGE: %d" % (number5 + number7))

